import argparse
import math
import time
import dill as pickle
# import pickle
# import pickle5 as pickle
from tqdm import tqdm

import torch
import torch.nn.functional as F
import torch.optim as optim
# from torchtext.data import Field, Dataset, BucketIterator
# from torchtext.datasets import TranslationDataset

import random
import numpy as np
# Pin every RNG this script draws from (Python, NumPy, PyTorch) to the same
# seed so runs are reproducible.
seed = 1
for _rng_seeder in (random.seed, np.random.seed, torch.manual_seed):
    _rng_seeder(seed)
del _rng_seeder
# Ask cuDNN for deterministic kernels and disable its autotuner, which may
# otherwise choose different algorithms from one run to the next.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

import time
import os
import matplotlib as mpl
import matplotlib.pyplot as plt
import importlib


class Log:
    """Per-run experiment logger that writes artifacts under a timestamped directory.

    Artifacts go to ``<root_path>/<method_name>/<YYYY-MM-DD>/<HH-MM-SS>/``.
    The ``/`` inside the strftime pattern creates one date directory with a
    sub-directory per run.
    """

    # Metric history: key -> list of appended values. Assigned per instance in
    # __init__; a class-level dict would be shared by every Log object.
    state_dict: dict

    def __init__(self, root_path, method_name):
        """Create (if needed) the log directory for this run.

        Args:
            root_path: Base directory for all logs.
            method_name: Sub-directory name identifying the method/experiment.
        """
        self.root_path = root_path
        self.method_name = method_name
        self.state_dict = {}
        self._time = time.strftime("%Y-%m-%d/%H-%M-%S", time.localtime())
        self._log_path = os.path.join(self.root_path, self.method_name, self._time)
        if not os.path.exists(self._log_path):
            # exist_ok avoids a crash if the directory appears between the
            # existence check above and this call.
            os.makedirs(self._log_path, exist_ok=True)
            print('log_path:', self._log_path)

    def record_opt(self, opt):
        """Write ``str(opt)`` to ``opt.txt`` in the log directory, overwriting it."""
        self._opt_path = os.path.join(self._log_path, 'opt.txt')
        with open(self._opt_path, 'w', encoding='utf-8') as f:
            f.write(str(opt))

    def state_dict_update(self, key_value_list):
        """Append each ``(key, value)`` pair to its history and persist the whole dict.

        Args:
            key_value_list: Iterable of ``(key, value)`` pairs.

        The full history is re-saved to ``state_dict.npy`` after every call.
        The dict is stored as a pickled 0-d object array, so reload it with
        ``np.load(path, allow_pickle=True).item()``.
        """
        for key, value in key_value_list:
            self.state_dict.setdefault(key, []).append(value)
        np.save(os.path.join(self._log_path, 'state_dict.npy'), self.state_dict)

    def save_model(self, model_name, checkpoint):
        """Save ``checkpoint`` with ``torch.save`` as ``model_name`` inside the log directory."""
        self._model_path = os.path.join(self._log_path, model_name)
        torch.save(checkpoint, self._model_path)

    def record_report(self, report_str):
        """Append ``report_str`` followed by a newline to ``report.txt``."""
        self._report_path = os.path.join(self._log_path, 'report.txt')
        with open(self._report_path, 'a', encoding='utf-8') as f:
            f.write(report_str + '\n')
